{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Tutorial 3: Train NicheTrans on SMA data\n",
    "\n",
    "This tutorial trains NicheTrans to predict the target (MSI) panel of each spot from its source (RNA) features. Two slides are used for training and a third, held-out slide for testing. Histology image patches are not used here (`use_img=False`). The weights from the final epoch are saved to `NicheTrans_SMA_last.pth`.\n",
    "\n",
    "Run the notebook from the directory that contains `args/`, `model/`, `datasets/` and `utils/`. Set the data paths (`path_img`, `rna_path`, `msi_path`) in `args/args_SMA.py` to your local copy of the data."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os, time, datetime, warnings\n",
    "\n",
    "import torch\n",
    "import torch.nn as nn\n",
    "from torch.optim import lr_scheduler\n",
    "\n",
    "# Project modules. These wildcard imports are assumed to supply\n",
    "# NicheTrans, set_seed and sma_dataloader, which are used below.\n",
    "from model.nicheTrans_img import *\n",
    "from datasets.data_manager_SMA import SMA\n",
    "\n",
    "from utils.utils import *\n",
    "from utils.utils_training_SMA import train, test\n",
    "from utils.utils_dataloader import *\n",
    "\n",
    "warnings.filterwarnings(\"ignore\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Load the arguments and fix the random seeds"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "==========\n",
      "Args:Namespace(dropout_rate=0.1, eval_step=1, gamma=0.1, gpu_devices='0', img_size=256, lr=0.0003, max_epoch=40, msi_path='/home/wzk/ST_data/SMA_data/Processed_data_v4', n_source=3000, n_target=50, noise_rate=0.2, optimizer='adam', path_img='/home/wzk/ST_data/SMA_data/Processed/patches', rna_path='/home/wzk/ST_data/SMA_data/Zhikang', save_dir='./log', seed=1, stepsize=20, test_batch=32, train_batch=32, weight_decay=0.0005, workers=4)\n",
      "==========\n"
     ]
    }
   ],
   "source": [
    "# `%run` executes the argument script in this notebook's namespace; it defines `args`.\n",
    "%run ./args/args_SMA.py\n",
    "\n",
    "# Restrict the visible GPUs before anything can initialise CUDA\n",
    "# (set_seed may also seed the CUDA RNG).\n",
    "os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu_devices\n",
    "set_seed(args.seed)\n",
    "\n",
    "print(\"==========\\nArgs:{}\\n==========\".format(args))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Build the dataloaders and the NicheTrans model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "------Calculating spatial graph...\n",
      "The graph contains 12134 edges, 3120 cells.\n",
      "3.8891 neighbors per cell on average.\n",
      "------Calculating spatial graph...\n",
      "The graph contains 24190 edges, 3120 cells.\n",
      "7.7532 neighbors per cell on average.\n",
      "------Calculating spatial graph...\n",
      "The graph contains 11322 edges, 2918 cells.\n",
      "3.8801 neighbors per cell on average.\n",
      "------Calculating spatial graph...\n",
      "The graph contains 22578 edges, 2918 cells.\n",
      "7.7375 neighbors per cell on average.\n",
      "------Calculating spatial graph...\n",
      "The graph contains 10360 edges, 2675 cells.\n",
      "3.8729 neighbors per cell on average.\n",
      "------Calculating spatial graph...\n",
      "The graph contains 20628 edges, 2675 cells.\n",
      "7.7114 neighbors per cell on average.\n",
      "=> SMA loaded\n",
      "Dataset statistics:\n",
      " ------------------------------\n",
      " subset | # num | \n",
      " ------------------------------\n",
      " train | Without filtering 6038 spots from 2 slides \n",
      " test | Without filtering 2675 spots from 1 slides\n",
      " train | After filting 6005 spots from 2 slides \n",
      " test | After filting 2655 spots from 1 slides\n",
      " ------------------------------\n"
     ]
    }
   ],
   "source": [
    "# Create the dataset (RNA source features, MSI targets) and its train/test dataloaders.\n",
    "dataset = SMA(path_img=args.path_img, rna_path=args.rna_path, msi_path=args.msi_path, n_top_genes=args.n_source, n_top_targets=args.n_target)\n",
    "trainloader, testloader = sma_dataloader(args, dataset)\n",
    "\n",
    "# Size the model's input/output layers from the loaded dataset.\n",
    "source_dimension, target_dimension = dataset.rna_length, dataset.msi_length\n",
    "model = NicheTrans(source_length=source_dimension, target_length=target_dimension, noise_rate=args.noise_rate, dropout_rate=args.dropout_rate)\n",
    "model = nn.DataParallel(model).cuda()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Set up the loss function, optimizer and learning-rate scheduler"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "# The target panel is regressed with a mean-squared-error loss.\n",
    "criterion = nn.MSELoss()\n",
    "\n",
    "# Note: weight decay is only applied with Adam; the SGD branch does not pass it.\n",
    "if args.optimizer == 'adam':\n",
    "    optimizer = torch.optim.Adam(model.parameters(), lr=args.lr, weight_decay=args.weight_decay)\n",
    "elif args.optimizer == 'SGD':\n",
    "    optimizer = torch.optim.SGD(model.parameters(), lr=args.lr)\n",
    "else:\n",
    "    # Fail here rather than with a NameError on `optimizer` once training starts.\n",
    "    raise ValueError(\"unexpected optimizer: {!r} (expected 'adam' or 'SGD')\".format(args.optimizer))\n",
    "\n",
    "# Multiply the learning rate by `gamma` every `stepsize` epochs (the scheduler is stepped\n",
    "# once per epoch below). A non-positive `stepsize` disables the schedule.\n",
    "if args.stepsize > 0:\n",
    "    scheduler = lr_scheduler.StepLR(optimizer, step_size=args.stepsize, gamma=args.gamma)\n",
    "else:\n",
    "    scheduler = None"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Train and evaluate the model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "start_time = time.time()\n",
    "\n",
    "for epoch in range(args.max_epoch):\n",
    "    last_epoch = epoch + 1 == args.max_epoch\n",
    "\n",
    "    print(\"==> Epoch {}/{}\".format(epoch + 1, args.max_epoch))\n",
    "\n",
    "    # Image patches are not used in this tutorial (use_img=False).\n",
    "    train(args, model, criterion, optimizer, trainloader, dataset.target_panel, use_img=False)\n",
    "    if scheduler is not None:\n",
    "        scheduler.step()\n",
    "\n",
    "    # Evaluate on the held-out test slide every `eval_step` epochs.\n",
    "    if (epoch + 1) % args.eval_step == 0:\n",
    "        pearson = test(args, model, testloader, dataset.target_panel, last_epoch, use_img=False)\n",
    "\n",
    "    if last_epoch:\n",
    "        # The model is wrapped in nn.DataParallel, so the saved keys carry a 'module.' prefix.\n",
    "        torch.save(model.state_dict(), 'NicheTrans_SMA_last.pth')\n",
    "\n",
    "elapsed = round(time.time() - start_time)\n",
    "elapsed = str(datetime.timedelta(seconds=elapsed))\n",
    "print(\"Finished. Total elapsed time (h:m:s): {}\".format(elapsed))"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "pytorch_zk",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.19"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}